﻿using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Xml.Linq;
using System.IO;

namespace SqlMetalInclude
{
	/// <summary>
	/// SqlMetalInclude
	/// 
	/// Used to filter the tables/views of a .dbml file generated by SqlMetal, optionally renaming
	/// the list (table member) and the entity type of each included table.
	/// Stored procedures/functions (Function elements) are currently left untouched.
	/// 
	/// By Jake Ginnivan
	/// </summary>
	static class Program
	{
		private static Arguments _args;

		/// <summary>Old table member (list) name -> new list name, recorded while filtering tables.</summary>
		private static readonly Dictionary<string, string> MemberRenames = new Dictionary<string, string>();

		/// <summary>Old type name -> new type name, recorded while filtering tables.</summary>
		private static readonly Dictionary<string, string> TypeRenames = new Dictionary<string, string>();

		/// <summary>
		/// Entry point. Validates the command line arguments and runs the requested filter mode.
		/// </summary>
		/// <param name="args">Command line arguments (see ShowHelp for the accepted switches).</param>
		/// <returns>0 on success or when help was requested, 1 when validation or the selected mode fails.</returns>
		static int Main(string[] args)
		{
			Console.WriteLine("SqlMetalInclude");

			_args = new Arguments(args);

			if (_args["?"] == "true")
			{
				ShowHelp();
				return 0;
			}

			var dbml = _args["dbml"];
			if (dbml == null || _args["output"] == null)
			{
				Console.WriteLine("Missing dbml input or output argument...");
				ShowHelp();
				return 1;
			}

			if (!File.Exists(dbml))
			{
				Console.WriteLine("{0} doesn't exist", dbml);
				return 1;
			}

			var includeFile = _args["includeFile"];
			if (includeFile != null)
			{
				if (!File.Exists(includeFile))
				{
					Console.WriteLine("{0} doesn't exist", includeFile);
					return 1;
				}

				// One include entry per line. string.Join (unlike Aggregate) also copes with an empty file;
				// blank lines become empty entries which RunInclude skips.
				_args.Add("include", string.Join(",", File.ReadAllLines(includeFile)));
			}

			//Run specified mode
			if (_args["include"] != null && _args["exclude"] != null)
			{
				Console.WriteLine("Include and exclude arguments cannot both be specified");
				return 1;
			}

			if (_args["include"] != null)
			{
				var useColumnName = _args["useColumnName"] == "true";

				if (!RunInclude(useColumnName))
					return 1;

				Console.WriteLine();
			}
			else if (_args["exclude"] != null)
			{
				if (!RunExclude())
					return 1;

				Console.WriteLine();
			}
			else
			{
				// Without a mode nothing would be written to the output file; report it instead of silently succeeding.
				Console.WriteLine("Missing include or exclude argument...");
				ShowHelp();
				return 1;
			}

			return 0;
		}

		/// <summary>
		/// Exclude mode. Not implemented yet: reports that and fails, instead of crashing with an unhandled exception.
		/// </summary>
		/// <returns>Always false.</returns>
		private static bool RunExclude()
		{
			Console.WriteLine("The exclude argument is not implemented yet, use include instead");
			return false;
		}

		/// <summary>
		/// Prints the command line usage.
		/// </summary>
		private static void ShowHelp()
		{
			Console.WriteLine("Used to post-process a dbml file created by sqlMetal.exe");
			Console.WriteLine();
			Console.WriteLine("SqlMetalInclude -dbml:Inputfile -output:newDbmlFilename");
			Console.WriteLine("[-include:includeConfig] [-includeFile:includeConfigFile] [-useColumnName]");
			Console.WriteLine("[-exclude:excludeConfig]");
			Console.WriteLine();
			Console.WriteLine("-dbml:Inputfile\tThe input dbml file generated by SqlMetal.exe");
			Console.WriteLine("-output:OutputFile\tThe filename of the new dbml file which has");
			Console.WriteLine("\t\tfiltered using the specified configuration");
			Console.WriteLine("-include:includeConfig");
			Console.WriteLine("\tSeparate the entities with comma's, ie table1,table2.");
			Console.WriteLine("\tTo rename the list of an entity follow it directly with a = sign then the list name");
			Console.WriteLine("\tTo rename both the entity name and list name the format is:");
			Console.WriteLine("\tSQLNAME=ListName/EntityName");
			Console.WriteLine("-includeFile:includeConfigFile");
			Console.WriteLine("\tReads the include configuration from a file, one entry per line.");
			Console.WriteLine("-useColumnName");
			Console.WriteLine("\tUse the database column names as the member names of the included entities.");
			Console.WriteLine("-exclude:excludeConfig");
			Console.WriteLine("\tNot implemented yet.");
			Console.WriteLine();
			Console.WriteLine("Example usage:");
			Console.WriteLine("Assume database has a table and a view I want, these are:");
			Console.WriteLine("Accounts and vwContacts.");
			Console.WriteLine("SqlMetalInclude -dbml:sqlMetalOut.dbml -output:small.dbml ");
			Console.WriteLine("-include:vwContacts=Contacts/Contact,Accounts");
			Console.WriteLine("This will include the Accounts table and the vwContacts view ");
			Console.WriteLine("and rename it, and the end result when the code is generated will be:");
			Console.WriteLine("List<Contact> Contacts");
		}

		/// <summary>
		/// Include mode: keeps only the tables listed in the "include" argument, applies the requested
		/// renames, fixes up associations and conflicting members, and saves the result to "output".
		/// </summary>
		/// <param name="useColumnName">When true, Member attributes are stripped from the columns of included tables.</param>
		/// <returns>True when the output was written; false when no table was specified.</returns>
		private static bool RunInclude(bool useColumnName)
		{
			var includedTables = ParseIncludes(_args["include"]);
			if (includedTables.Count == 0)
			{
				// An empty list would remove every table, which is almost certainly a configuration mistake.
				Console.WriteLine("No tables were specified to include");
				return false;
			}

			var doc = XDocument.Load(_args["dbml"]);
			RemoveNotIncluded(includedTables, doc, useColumnName);

			UpdateReferences(doc);
			RemoveExtraAssociations(doc);
			RenameConflictingMembers(doc);

			doc.Save(_args["output"]);

			return true;
		}

		/// <summary>
		/// Parses an include configuration of the form
		/// <c>Table[=ListName[/TypeName]]</c>, entries separated by ',' or '|'.
		/// </summary>
		/// <param name="includeConfig">The raw include argument.</param>
		/// <returns>One <see cref="Table"/> per valid entry, in input order.</returns>
		/// <remarks>
		/// A schema prefix such as "dbo." may be given or omitted: RemoveNotIncluded compares only the
		/// part after the last '.'.
		/// </remarks>
		private static List<Table> ParseIncludes(string includeConfig)
		{
			var tables = new List<Table>();

			foreach (var rawEntry in includeConfig.Split(new[] { '|', ',' }, StringSplitOptions.RemoveEmptyEntries))
			{
				// Entries coming from an include file may carry surrounding whitespace.
				var entry = rawEntry.Trim();
				if (entry.Length == 0)
					continue;

				var parts = entry.Split(new[] { '=' }, StringSplitOptions.RemoveEmptyEntries);
				if (parts.Length == 0)
				{
					Console.WriteLine("Ignoring invalid include entry '{0}'", entry);
					continue;
				}

				var table = new Table
								{
									TableName = parts[0].Trim()
								};

				if (parts.Length == 2)
				{
					var renames = parts[1].Split(new[] { '/' }, StringSplitOptions.RemoveEmptyEntries);

					// "=ListName" renames the list only, "=ListName/TypeName" renames both; anything else is ignored.
					if (renames.Length == 1 || renames.Length == 2)
						table.ListName = renames[0].Trim();
					if (renames.Length == 2)
						table.TypeName = renames[1].Trim();
				}

				tables.Add(table);
			}

			return tables;
		}

		/// <summary>
		/// Removes every association whose target Type is no longer declared in the document
		/// (i.e. it points at a table that was not included), printing a message for each.
		/// </summary>
		/// <param name="doc">The dbml document.</param>
		private static void RemoveExtraAssociations(XDocument doc)
		{
			// HashSet keeps the per-association lookup O(1).
			var typeNames = new HashSet<string>(
				doc.Descendants(DbmlElements.Type).Attributes("Name").Select(a => a.Value));

			RemoveAssociation(doc, typeNames);
		}

		/// <summary>
		/// Renames conflicting type members. (members whose name is equal to the classes' name)
		/// </summary>
		/// <param name="doc">The document containing entity information.</param>
		private static void RenameConflictingMembers(XDocument doc)
		{
			// Materialised because the per-type pass modifies attributes inside the tree.
			var types = doc.Descendants(DbmlElements.Type).ToList();

			foreach (var type in types)
			{
				RenameConflictingMembers(type);
			}
		}

		/// <summary>
		/// Renames the column of the specified type whose member name equals the type's name.
		/// The member becomes "Name" when that is free, otherwise "@TypeName"; its storage becomes "_TypeName".
		/// </summary>
		/// <param name="type">The Type element.</param>
		/// <remarks>
		/// Only columns declared directly on this type are inspected; nested (derived) Type elements
		/// are visited separately by the document-level overload.
		/// NOTE(review): "@TypeName" is a C# verbatim identifier that still equals TypeName — confirm the
		/// code generator accepts it for the "Name"-already-taken case.
		/// </remarks>
		private static void RenameConflictingMembers(XElement type)
		{
			var typeNameAttribute = type.Attribute("Name");
			Debug.Assert(typeNameAttribute != null);

			var typeName = typeNameAttribute.Value;
			var columns = type.Elements(DbmlElements.Column).ToList();

			bool nameMemberExists = columns.Any(column => GetMemberName(column) == "Name");

			foreach (var column in columns)
			{
				if (GetMemberName(column) != typeName)
					continue;

				Console.WriteLine("Correcting member {0}", typeName);

				column.SetAttributeValue("Member", nameMemberExists
													? string.Format("@{0}", typeName)
													: "Name");
				column.SetAttributeValue("Storage", string.Format("_{0}", typeName));
			}
		}

		/// <summary>
		/// Gets the generated member name of a column: its Member attribute when present, otherwise its Name.
		/// </summary>
		/// <param name="column">The Column element.</param>
		/// <returns>The member name, or null when the column has neither attribute.</returns>
		private static string GetMemberName(XElement column)
		{
			var memberAttribute = column.Attribute("Member");
			if (memberAttribute != null)
				return memberAttribute.Value;

			var nameAttribute = column.Attribute("Name");
			return nameAttribute == null ? null : nameAttribute.Value;
		}

		/// <summary>
		/// Removes every association whose Type attribute is not in <paramref name="typeNames"/>.
		/// </summary>
		/// <param name="doc">The container to search.</param>
		/// <param name="typeNames">The type names that still exist.</param>
		private static void RemoveAssociation(XContainer doc, ICollection<string> typeNames)
		{
			// Materialised first because matching associations are removed from the tree inside the loop.
			var associations = doc.Descendants(DbmlElements.Association).ToList();
			foreach (var association in associations
				.Where(association =>
						{
							var xAttribute = association.Attribute("Type");
							Debug.Assert(xAttribute != null);
							return !typeNames.Contains(xAttribute.Value);
						}))
			{
				var nameAttribute = association.Attribute("Name");
				Debug.Assert(nameAttribute != null);
				Debug.Assert(association.Parent != null);
				var parentNameAttribute = association.Parent.Attribute("Name");
				Debug.Assert(parentNameAttribute != null);

				Console.WriteLine("Removing association '{0}' from table '{1}' because it has not been included",
								  nameAttribute.Value,
								  parentNameAttribute.Value);

				association.Remove();
			}
		}

		/// <summary>
		/// Updates the references. This is to make sure the foreign keys point at any renamed tables.
		/// </summary>
		/// <param name="doc">The doc.</param>
		private static void UpdateReferences(XContainer doc)
		{
			foreach (var table in doc.Descendants(DbmlElements.Table))
			{
				var associations = table.Descendants(DbmlElements.Association).ToList();

				foreach (var association in associations)
				{
					//Get the important attributes
					var memberName = association.Attribute("Member");
					var typeName = association.Attribute("Type");

					Debug.Assert(memberName != null);
					Debug.Assert(typeName != null);

					//Check for members that have been renamed
					//3/12/08 Replace dbo. in case the dbo is specified because that is not valid for member and type names
					//31/5/12 Done for consistent naming only and does not have to do with integrity nor a working model
					string memberRename;
					if (MemberRenames.TryGetValue(memberName.Value, out memberRename))
						memberName.Value = memberRename.Replace("dbo.", string.Empty);

					string typeRename;
					if (TypeRenames.TryGetValue(typeName.Value, out typeRename))
						typeName.Value = typeRename.Replace("dbo.", string.Empty);
				}
			}
		}

		/// <summary>
		/// Removes every table that has not been included, applies the requested list/type renames to
		/// the included ones and records those renames in MemberRenames / TypeRenames.
		/// </summary>
		/// <param name="tables">The tables to keep, with their optional renames.</param>
		/// <param name="doc">The dbml xml document</param>
		/// <param name="useColumnName">When true, Member attributes are removed from the columns of included tables.</param>
		/// <remarks>
		/// Tables are matched case-insensitively on the part of the name after the last '.', so schema
		/// prefixes are ignored on both sides.
		/// </remarks>
		private static void RemoveNotIncluded(List<Table> tables, XContainer doc, bool useColumnName)
		{
			var tableDictionary = new Dictionary<string, Table>(StringComparer.InvariantCultureIgnoreCase);
			foreach (var table in tables)
			{
				var key = table.TableName.Split('.').Last(); //Compare table name only
				if (tableDictionary.ContainsKey(key))
					Console.WriteLine("Table {0} is included more than once, the last entry is used.", key);

				tableDictionary[key] = table;
			}

			// Materialised first because non-included tables are removed from the tree inside the loop.
			var tableElements = doc.Descendants(DbmlElements.Table).ToList();

			int tablesRemoved = 0;
			foreach (var table in tableElements)
			{
				var tableNameAttribute = table.Attribute("Name");
				Debug.Assert(tableNameAttribute != null, "tableNameAttribute != null");

				var tableName = tableNameAttribute.Value.Split('.').Last(); //Compare table name only

				Table tableDefinition;
				if (!tableDictionary.TryGetValue(tableName, out tableDefinition))
				{
					Console.WriteLine("Table {0} removed.", tableName);
					table.Remove();
					tablesRemoved++;
					continue;
				}

				var listNameAttribute = table.Attribute("Member");
				Debug.Assert(listNameAttribute != null, "listAttribute != null");

				var listName = listNameAttribute.Value;
				if (tableDefinition.ListName != null && listName != tableDefinition.ListName)
				{
					Console.WriteLine("Table Member {0} will be renamed to {1}.", listName, tableDefinition.ListName);
					listNameAttribute.Value = tableDefinition.ListName;
					MemberRenames[listName] = tableDefinition.ListName;
				}

				var typeElement = table.Element(DbmlElements.Type);
				Debug.Assert(typeElement != null, "typeAttribute != null");

				var typeNameAttribute = typeElement.Attribute("Name");
				Debug.Assert(typeNameAttribute != null, "typeNameAttribute != null");

				var typeName = typeNameAttribute.Value;
				if (tableDefinition.TypeName != null && typeName != tableDefinition.TypeName)
				{
					Console.WriteLine("Type {0} will be renamed to {1}.", typeName, tableDefinition.TypeName);
					typeNameAttribute.Value = tableDefinition.TypeName;
					TypeRenames[typeName] = tableDefinition.TypeName;
				}

				if (useColumnName)
				{
					// Without a Member attribute the column's Name is used as the member name.
					foreach (var column in typeElement.Elements(DbmlElements.Column))
					{
						var memberAttribute = column.Attribute("Member");
						if (memberAttribute != null)
							memberAttribute.Remove();
					}
				}
			}

			Console.WriteLine("{0} tables removed", tablesRemoved);
		}
	}
}
